WindowedAttention Initial implementation #4151
base: main
Conversation
- Tried to stick to the cross-attention conventions
- Still needs features like GQA (grouped-query attention) and MQA (multi-query attention)
Codecov Report

❌ Patch coverage is […]

❌ Your project check has failed because the head coverage (68.48%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

@@            Coverage Diff             @@
##             main    #4151      +/-   ##
==========================================
+ Coverage   68.42%   68.48%   +0.05%
==========================================
  Files        1281     1282       +1
  Lines      157264   157576     +312
==========================================
+ Hits       107602   107909     +307
- Misses      49662    49667       +5
Hey, I’m working on sliding-window attention for an audio perception project and planning to implement optimized CubeCL kernels for it. I wanted to coordinate so we’re not duplicating effort. What's the current approach?
I'm happy to go with your approach instead if you want; then you can reference it, and I'll check how it deviates from mine and whether mine would add value. I'm not continuing work on it for now.
This PR has been marked as stale because it has not been updated for over a month.
Hey @huy209vn, any progress? I'd be happy to take it over again 🙏
- Add n_heads_kv config option for grouped-query and multi-query attention
- Fix dropout to apply after softmax instead of before (on the attention weights)
- Add reshape_kv and repeat_kv helpers for GQA head expansion
- Update ModuleDisplay to show n_heads_kv
- Add tests for GQA, MQA, and the invalid-config assertion
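For context on the GQA/MQA items above: the heart of a repeat_kv-style helper is the mapping from each query head to the KV head it shares. Here is a minimal, framework-agnostic sketch of that index logic in plain Rust; the real module operates on Burn tensors, and kv_head_for is a hypothetical name, while n_heads/n_heads_kv echo the config options listed above.

```rust
/// Map a query-head index to the KV-head index it shares under GQA/MQA.
/// With n_heads_kv == n_heads this is standard multi-head attention;
/// with n_heads_kv == 1 it degenerates to multi-query attention.
fn kv_head_for(query_head: usize, n_heads: usize, n_heads_kv: usize) -> usize {
    assert!(
        n_heads % n_heads_kv == 0,
        "n_heads must be divisible by n_heads_kv"
    );
    let group_size = n_heads / n_heads_kv;
    query_head / group_size
}

fn main() {
    // 8 query heads sharing 2 KV heads: heads 0..3 use KV head 0, heads 4..7 use KV head 1.
    let mapping: Vec<usize> = (0..8).map(|h| kv_head_for(h, 8, 2)).collect();
    assert_eq!(mapping, vec![0, 0, 0, 0, 1, 1, 1, 1]);
}
```

In tensor form this mapping is usually realized by repeating the K/V heads group_size times along the head dimension, which is presumably what the repeat_kv helper does.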
Checklist
The cargo run-checks command has been executed.

Related Issues/PRs
#4096 (windowed attention portion)
Changes
Adds windowed self-attention module:
- generate_sliding_window_mask() utility in mask.rs
- WindowedAttention module with causal and bidirectional modes
- WindowedAttentionCache - rolling KV cache storing only the last window_size pairs

Still needed: […]
cc @huy209vn - I got started, feel free to add as you wish. I'll add as well when I have some time.
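As a shape-level illustration of the mask and cache pieces listed under Changes, here is a hedged sketch in plain Rust rather than Burn tensors. sliding_window_mask and RollingKvCache are hypothetical stand-ins for the PR's generate_sliding_window_mask() and WindowedAttentionCache, and the exact window convention (inclusive of the current position) is an assumption:

```rust
use std::collections::VecDeque;

/// Boolean mask where entry [q][k] == true means key k is hidden from query q.
fn sliding_window_mask(seq_len: usize, window_size: usize, causal: bool) -> Vec<Vec<bool>> {
    (0..seq_len)
        .map(|q| {
            (0..seq_len)
                .map(|k| {
                    if causal {
                        // Visible keys: the last `window_size` positions up to and including q.
                        k > q || q - k >= window_size
                    } else {
                        // Visible keys: positions within `window_size` of q on either side.
                        q.abs_diff(k) >= window_size
                    }
                })
                .collect()
        })
        .collect()
}

/// Rolling KV cache keeping only the last `window_size` (key, value) pairs,
/// so memory stays O(window_size) regardless of sequence length.
struct RollingKvCache<K, V> {
    window_size: usize,
    entries: VecDeque<(K, V)>,
}

impl<K, V> RollingKvCache<K, V> {
    fn new(window_size: usize) -> Self {
        Self {
            window_size,
            entries: VecDeque::with_capacity(window_size),
        }
    }

    fn push(&mut self, key: K, value: V) {
        if self.entries.len() == self.window_size {
            self.entries.pop_front(); // evict the oldest pair
        }
        self.entries.push_back((key, value));
    }
}

fn main() {
    // With seq_len = 4 and window_size = 2, causal query 2 sees keys 1 and 2 only.
    let mask = sliding_window_mask(4, 2, true);
    assert_eq!(mask[2], vec![true, false, false, true]);

    let mut cache = RollingKvCache::new(2);
    for t in 0..5 {
        cache.push(t, t * 10);
    }
    // Only the last two timesteps survive.
    let kept: Vec<_> = cache.entries.iter().map(|e| e.0).collect();
    assert_eq!(kept, vec![3, 4]);
}
```

The ring-buffer eviction is what keeps decoding memory bounded: once window_size pairs are stored, each new token displaces the oldest one, matching the "rolling KV cache storing only the last window_size pairs" described above.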
Testing
window_size
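For illustration, a sketch of the kind of unit test that pins down the window_size behavior, written against the hypothetical sliding_window_mask helper from the sketch above (not the PR's actual test suite):

```rust
#[test]
fn causal_mask_respects_window_size() {
    let mask = sliding_window_mask(5, 2, true);
    // Query 3 sees keys 2 and 3 (the last `window_size` positions)...
    assert!(!mask[3][2] && !mask[3][3]);
    // ...but not key 1 (outside the window) or key 4 (in the future).
    assert!(mask[3][1] && mask[3][4]);
}
```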